{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the entry points to Spark, stopping any context left over from an earlier run.\n",
    "try:\n",
    "    sc.stop()\n",
    "except NameError:\n",
    "    pass  # no SparkContext exists yet\n",
    "from pyspark import SparkContext, SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "sc = SparkContext()\n",
    "spark = SparkSession(sparkContext=sc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `udf()` function and sql types"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `pyspark.sql.functions.udf()` function converts an ordinary Python function (a **user defined function**) into a **`pyspark.sql.functions`**-style function that can act on the columns of a DataFrame. This makes data transformation much more flexible.\n",
    "\n",
    "Using `udf()` can be tricky. The key is to understand how to define the `returnType` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.types import *\n",
    "from pyspark.sql.functions import udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+\n",
      "|            model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|\n",
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+\n",
      "|        Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|\n",
      "|    Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|\n",
      "|       Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|\n",
      "|   Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|\n",
      "|Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2|\n",
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "mtcars = spark.read.csv('../../data/mtcars.csv', inferSchema=True, header=True)\n",
    "mtcars = mtcars.withColumnRenamed('_c0', 'model')\n",
    "mtcars.show(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**The structure of the schema passed to `returnType` has to match the data structure of the return value from the user defined function**."
   ]
  },
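  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of what a mismatch can look like, using a hypothetical helper `cyl_plus_one` that is introduced here for illustration only. The helper returns a Python `int`, so declaring `FloatType()` does not match the return value. With a plain Python udf, Spark typically fills such a column with `null` rather than casting the value, while declaring `IntegerType()` matches and should give the expected numbers. The exact behavior may differ between Spark versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.types import FloatType, IntegerType\n",
    "from pyspark.sql.functions import udf\n",
    "\n",
    "# hypothetical helper for illustration: returns a Python int\n",
    "def cyl_plus_one(cyl):\n",
    "    return int(cyl) + 1\n",
    "\n",
    "# declared type does not match the returned int\n",
    "mismatched_udf = udf(cyl_plus_one, returnType=FloatType())\n",
    "# declared type matches the returned int\n",
    "matched_udf = udf(cyl_plus_one, returnType=IntegerType())\n",
    "\n",
    "# compare the two result columns side by side\n",
    "mtcars.select(\n",
    "    mtcars.cyl,\n",
    "    mismatched_udf(mtcars.cyl).alias('mismatched'),\n",
    "    matched_udf(mtcars.cyl).alias('matched')\n",
    ").show(5)"
   ]
  },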
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Case 1**: divide **disp** by **hp** and put the result in a new column\n",
    "\n",
    "The user defined function returns a float value, so `returnType` is `FloatType()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def disp_by_hp(disp, hp):\n",
    "    return disp / hp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "disp_by_hp_udf = udf(disp_by_hp, returnType=FloatType())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Column<b'model'>,\n",
       " Column<b'mpg'>,\n",
       " Column<b'cyl'>,\n",
       " Column<b'disp'>,\n",
       " Column<b'hp'>,\n",
       " Column<b'drat'>,\n",
       " Column<b'wt'>,\n",
       " Column<b'qsec'>,\n",
       " Column<b'vs'>,\n",
       " Column<b'am'>,\n",
       " Column<b'gear'>,\n",
       " Column<b'carb'>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# look up each column by name; this avoids eval() on strings built from column names\n",
    "all_original_cols = [mtcars[x] for x in mtcars.columns]\n",
    "all_original_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Column<b'disp_by_hp(disp, hp)'>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disp_by_hp_col = disp_by_hp_udf(mtcars.disp, mtcars.hp)\n",
    "disp_by_hp_col"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Column<b'model'>,\n",
       " Column<b'mpg'>,\n",
       " Column<b'cyl'>,\n",
       " Column<b'disp'>,\n",
       " Column<b'hp'>,\n",
       " Column<b'drat'>,\n",
       " Column<b'wt'>,\n",
       " Column<b'qsec'>,\n",
       " Column<b'vs'>,\n",
       " Column<b'am'>,\n",
       " Column<b'gear'>,\n",
       " Column<b'carb'>,\n",
       " Column<b'disp_by_hp(disp, hp)'>]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_new_cols = all_original_cols + [disp_by_hp_col]\n",
    "all_new_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+\n",
      "|              model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|disp_by_hp(disp, hp)|\n",
      "+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+\n",
      "|          Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|           1.4545455|\n",
      "|      Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|           1.4545455|\n",
      "|         Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|           1.1612903|\n",
      "|     Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|           2.3454545|\n",
      "|  Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2|            2.057143|\n",
      "|            Valiant|18.1|  6|225.0|105|2.76| 3.46|20.22|  1|  0|   3|   1|            2.142857|\n",
      "|         Duster 360|14.3|  8|360.0|245|3.21| 3.57|15.84|  0|  0|   3|   4|           1.4693878|\n",
      "|          Merc 240D|24.4|  4|146.7| 62|3.69| 3.19| 20.0|  1|  0|   4|   2|            2.366129|\n",
      "|           Merc 230|22.8|  4|140.8| 95|3.92| 3.15| 22.9|  1|  0|   4|   2|           1.4821053|\n",
      "|           Merc 280|19.2|  6|167.6|123|3.92| 3.44| 18.3|  1|  0|   4|   4|           1.3626016|\n",
      "|          Merc 280C|17.8|  6|167.6|123|3.92| 3.44| 18.9|  1|  0|   4|   4|           1.3626016|\n",
      "|         Merc 450SE|16.4|  8|275.8|180|3.07| 4.07| 17.4|  0|  0|   3|   3|           1.5322223|\n",
      "|         Merc 450SL|17.3|  8|275.8|180|3.07| 3.73| 17.6|  0|  0|   3|   3|           1.5322223|\n",
      "|        Merc 450SLC|15.2|  8|275.8|180|3.07| 3.78| 18.0|  0|  0|   3|   3|           1.5322223|\n",
      "| Cadillac Fleetwood|10.4|  8|472.0|205|2.93| 5.25|17.98|  0|  0|   3|   4|            2.302439|\n",
      "|Lincoln Continental|10.4|  8|460.0|215| 3.0|5.424|17.82|  0|  0|   3|   4|            2.139535|\n",
      "|  Chrysler Imperial|14.7|  8|440.0|230|3.23|5.345|17.42|  0|  0|   3|   4|           1.9130435|\n",
      "|           Fiat 128|32.4|  4| 78.7| 66|4.08|  2.2|19.47|  1|  1|   4|   1|           1.1924243|\n",
      "|        Honda Civic|30.4|  4| 75.7| 52|4.93|1.615|18.52|  1|  1|   4|   2|           1.4557692|\n",
      "|     Toyota Corolla|33.9|  4| 71.1| 65|4.22|1.835| 19.9|  1|  1|   4|   1|           1.0938462|\n",
      "+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "mtcars.select(all_new_cols).show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Case 2**: create an array column that contains the **disp** and **hp** values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the function: it returns a list of two floats\n",
    "def merge_two_columns(col1, col2):\n",
    "    return [float(col1), float(col2)]\n",
    "\n",
    "# wrap the user defined function as a udf (a SQL function) that returns an array of floats\n",
    "array_merge_two_columns_udf = udf(merge_two_columns, returnType=ArrayType(FloatType()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Column<b'merge_two_columns(disp, hp)'>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array_col = array_merge_two_columns_udf(mtcars.disp, mtcars.hp)\n",
    "array_col"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Column<b'model'>,\n",
       " Column<b'mpg'>,\n",
       " Column<b'cyl'>,\n",
       " Column<b'disp'>,\n",
       " Column<b'hp'>,\n",
       " Column<b'drat'>,\n",
       " Column<b'wt'>,\n",
       " Column<b'qsec'>,\n",
       " Column<b'vs'>,\n",
       " Column<b'am'>,\n",
       " Column<b'gear'>,\n",
       " Column<b'carb'>,\n",
       " Column<b'merge_two_columns(disp, hp)'>]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_new_cols = all_original_cols + [array_col]\n",
    "all_new_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+\n",
      "|model            |mpg |cyl|disp |hp |drat|wt   |qsec |vs |am |gear|carb|merge_two_columns(disp, hp)|\n",
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+\n",
      "|Mazda RX4        |21.0|6  |160.0|110|3.9 |2.62 |16.46|0  |1  |4   |4   |[160.0, 110.0]             |\n",
      "|Mazda RX4 Wag    |21.0|6  |160.0|110|3.9 |2.875|17.02|0  |1  |4   |4   |[160.0, 110.0]             |\n",
      "|Datsun 710       |22.8|4  |108.0|93 |3.85|2.32 |18.61|1  |1  |4   |1   |[108.0, 93.0]              |\n",
      "|Hornet 4 Drive   |21.4|6  |258.0|110|3.08|3.215|19.44|1  |0  |3   |1   |[258.0, 110.0]             |\n",
      "|Hornet Sportabout|18.7|8  |360.0|175|3.15|3.44 |17.02|0  |0  |3   |2   |[360.0, 175.0]             |\n",
      "+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "mtcars.select(all_new_cols).show(5, truncate=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `ArrayType` vs. `StructType`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both `ArrayType` and `StructType` can be used to build the `returnType` for a function that returns a list. The differences are:\n",
    "\n",
    "1. `ArrayType` requires all elements in the list to have the same `elementType`, while each field of a `StructType` can have its own data type.\n",
    "2. `StructType` represents a `Row` object, so each value belongs to a named field."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Define an `ArrayType` with elementType being `FloatType`.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the function: both list elements are floats\n",
    "def merge_two_columns(col1, col2):\n",
    "    return [float(col1), float(col2)]\n",
    "\n",
    "array_type = ArrayType(FloatType())\n",
    "array_merge_two_columns_udf = udf(merge_two_columns, returnType=array_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Define a `StructType` with one field of `StringType` and another of `FloatType`.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the function: the first list element is a string, the second a float\n",
    "def merge_two_columns(col1, col2):\n",
    "    return [str(col1), float(col2)]\n",
    "\n",
    "struct_type = StructType([\n",
    "    StructField('f1', StringType()),\n",
    "    StructField('f2', FloatType())\n",
    "])\n",
    "struct_merge_two_columns_udf = udf(merge_two_columns, returnType=struct_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**array column** expression: both values are floats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Column<b'merge_two_columns(hp, disp)'>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "array_col = array_merge_two_columns_udf(mtcars.hp, mtcars.disp)\n",
    "array_col"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**struct column** expression: the first value is a string and the second value is a float"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Column<b'merge_two_columns(model, disp)'>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "struct_col = struct_merge_two_columns_udf(mtcars.model, mtcars.disp)\n",
    "struct_col"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------------------------+------------------------------+\n",
      "|merge_two_columns(hp, disp)|merge_two_columns(model, disp)|\n",
      "+---------------------------+------------------------------+\n",
      "|[110.0, 160.0]             |[Mazda RX4, 160.0]            |\n",
      "|[110.0, 160.0]             |[Mazda RX4 Wag, 160.0]        |\n",
      "|[93.0, 108.0]              |[Datsun 710, 108.0]           |\n",
      "|[110.0, 258.0]             |[Hornet 4 Drive, 258.0]       |\n",
      "|[175.0, 360.0]             |[Hornet Sportabout, 360.0]    |\n",
      "|[105.0, 225.0]             |[Valiant, 225.0]              |\n",
      "|[245.0, 360.0]             |[Duster 360, 360.0]           |\n",
      "|[62.0, 146.7]              |[Merc 240D, 146.7]            |\n",
      "|[95.0, 140.8]              |[Merc 230, 140.8]             |\n",
      "|[123.0, 167.6]             |[Merc 280, 167.6]             |\n",
      "|[123.0, 167.6]             |[Merc 280C, 167.6]            |\n",
      "|[180.0, 275.8]             |[Merc 450SE, 275.8]           |\n",
      "|[180.0, 275.8]             |[Merc 450SL, 275.8]           |\n",
      "|[180.0, 275.8]             |[Merc 450SLC, 275.8]          |\n",
      "|[205.0, 472.0]             |[Cadillac Fleetwood, 472.0]   |\n",
      "|[215.0, 460.0]             |[Lincoln Continental, 460.0]  |\n",
      "|[230.0, 440.0]             |[Chrysler Imperial, 440.0]    |\n",
      "|[66.0, 78.7]               |[Fiat 128, 78.7]              |\n",
      "|[52.0, 75.7]               |[Honda Civic, 75.7]           |\n",
      "|[65.0, 71.1]               |[Toyota Corolla, 71.1]        |\n",
      "+---------------------------+------------------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "mtcars.select(array_col, struct_col).show(truncate=False)"
   ]
  },
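  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a follow-up sketch, assuming `array_col` and `struct_col` from the cells above, the two types can be told apart by printing the schema and by how their values are selected: array elements by position with `getItem()`, struct fields by name with `getField()`. The aliases used here are hypothetical names chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the schema shows an array column next to a struct column with fields f1 and f2\n",
    "mtcars.select(array_col, struct_col).printSchema()\n",
    "\n",
    "# array elements are selected by index, struct fields by name\n",
    "mtcars.select(\n",
    "    array_col.getItem(0).alias('first_array_element'),\n",
    "    struct_col.getField('f1').alias('struct_f1'),\n",
    "    struct_col.getField('f2').alias('struct_f2')\n",
    ").show(5, truncate=False)"
   ]
  },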
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
